{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8087b53f-c12d-4a1c-a62c-e1ced4f197b6",
   "metadata": {},
   "source": [
    "# Instructions\n",
    "If you haven't already, work through the [infer_quickstart.ipynb](https://github.com/bowang-lab/ECG-FM/blob/main/notebooks/infer_quickstart.ipynb) notebook first. This tutorial assumes that you have completed the installation and model setup.\n",
    "\n",
    "This tutorial focuses on performing inference through the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) command-line functionality, which is useful for the large-scale computation and storage of results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43635fe2-aa31-4aed-8f19-da356d3c0177",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "root = os.path.dirname(os.getcwd())\n",
    "\n",
    "fairseq_signals_root = '/path/to/fairseq-signals'  # TODO: replace with the path to your local fairseq-signals clone\n",
    "fairseq_signals_root = fairseq_signals_root.rstrip('/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "186a775d-42fb-46d5-9686-3e1e2424c2b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_def = pd.read_csv(\n",
    "    os.path.join(root, 'data/mimic_iv_ecg/labels/label_def.csv'),\n",
    "    index_col='name',\n",
    ")\n",
    "label_names = label_def.index\n",
    "label_names"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a6e9a4e-c23b-4f9b-aa82-0a7831042dc3",
   "metadata": {},
   "source": [
    "## Data manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75624573-d1c5-4ab6-8dab-f682bb1c349f",
   "metadata": {},
   "source": [
    "The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b99a8f-6eea-4c60-9c27-c2ef388c4a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "segmented_split = pd.read_csv(\n",
    "    os.path.join(root, 'data/code_15/segmented_split_incomplete.csv'),\n",
    "    index_col='idx',\n",
    ")\n",
    "segmented_split['path'] = (root + '/data/code_15/segmented/') + segmented_split['path']\n",
    "segmented_split.to_csv(os.path.join(root, 'data/code_15/segmented_split.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61710580-87fc-43fe-9d7a-0465ebed46ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'data/code_15/segmented_split.csv'))"
   ]
  },
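  {
   "cell_type": "markdown",
   "id": "4b7e2c1a-9d3f-4e58-8a61-2f0c5d7e9b13",
   "metadata": {},
   "source": [
    "As an optional sanity check, the cell below sketches one way to confirm that no relative paths remain in the split. The helper name `relative_paths` is hypothetical (it is not part of fairseq-signals), and the toy list stands in for `segmented_split['path']`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a1d8f30-2c4b-4f7e-b5a9-7e3c0d9f1a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def relative_paths(paths):\n",
    "    # Return the entries that are not absolute file paths\n",
    "    return [p for p in paths if not os.path.isabs(p)]\n",
    "\n",
    "\n",
    "# Toy input; in this notebook, pass segmented_split['path'] instead\n",
    "print(relative_paths([os.path.abspath('a.mat'), 'b.mat']))"
   ]
  },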
  {
   "cell_type": "markdown",
   "id": "b71fabac-3150-433d-9fe8-c559ec71022a",
   "metadata": {},
   "source": [
    "Run the following commands to generate the `test.tsv` file used for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30575700-2503-4052-a7a9-58adc26e634c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"\"\"cd {fairseq_signals_root}/scripts/preprocess\n",
    "python manifests.py \\\\\n",
    "    --split_file_paths \"{root}/data/code_15/segmented_split.csv\" \\\\\n",
    "    --save_dir \"{root}/data/manifests/code_15_subset10/\"\n",
    "\"\"\")"
   ]
  },
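  {
   "cell_type": "markdown",
   "id": "d2f5a8c7-1b3e-4a69-9c40-8e7b6a5d4c32",
   "metadata": {},
   "source": [
    "If `root` contains spaces or other shell-special characters, hand-written quoting becomes fragile. As a sketch, the same command can be assembled as an argument list and quoted for a POSIX shell with the standard library's `shlex.join`. The function name `manifest_command` and the example paths are hypothetical; the flags are the ones used in the printed command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c4b7a2-5d6f-4c18-a3b0-1f2e3d4c5b67",
   "metadata": {},
   "outputs": [],
   "source": [
    "import shlex\n",
    "\n",
    "\n",
    "def manifest_command(split_file, save_dir):\n",
    "    # Build the manifests.py call as an argument list, then shell-quote it\n",
    "    args = [\n",
    "        'python', 'manifests.py',\n",
    "        '--split_file_paths', split_file,\n",
    "        '--save_dir', save_dir,\n",
    "    ]\n",
    "    return shlex.join(args)\n",
    "\n",
    "\n",
    "# Hypothetical paths; arguments containing a space are quoted automatically\n",
    "print(manifest_command('/my data/segmented_split.csv', '/my data/manifests'))"
   ]
  },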
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1f13589-c5ff-4a04-bd42-73bdfc807f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'data/manifests/code_15_subset10/test.tsv'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfb153be-9d87-47f3-8855-3f5b947e53a3",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "Inside our environment, we can run the following command using Hydra's command-line interface to extract the logits/targets, as well as the precursor results needed to obtain the embeddings and saliency maps.\n",
    "\n",
    "The [embs.py](https://github.com/bowang-lab/ECG-FM/blob/main/scripts/embs.py) and [saliency.py](https://github.com/bowang-lab/ECG-FM/blob/main/scripts/saliency.py) scripts can then be used to convert these precursor results into their final form. See `infer_quickstart.ipynb` for visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe363b7-8d1b-4831-91eb-9b326fed1593",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"\"\"fairseq-hydra-inference \\\\\n",
    "    task.data=\"{root}/data/manifests/code_15_subset10/\" \\\\\n",
    "    common_eval.path=\"{root}/ckpts/mimic_iv_ecg_finetuned.pt\" \\\\\n",
    "    common_eval.extract=[output,encoder_out,saliency] \\\\\n",
    "    common_eval.results_path=\"{root}/outputs\" \\\\\n",
    "    model.num_labels={len(label_names)} \\\\\n",
    "    dataset.valid_subset=test \\\\\n",
    "    dataset.batch_size=10 \\\\\n",
    "    dataset.num_workers=3 \\\\\n",
    "    dataset.disable_validation=false \\\\\n",
    "    distributed_training.distributed_world_size=1 \\\\\n",
    "    distributed_training.find_unused_parameters=True \\\\\n",
    "    --config-dir \"{root}/ckpts\" \\\\\n",
    "    --config-name mimic_iv_ecg_finetuned\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "033dfa89-d249-410b-b5ab-39737686b4b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'outputs/outputs_test.npy'))\n",
    "assert os.path.isfile(os.path.join(root, 'outputs/outputs_test_header.pkl'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "635f2f4b-fc76-4b09-87f4-82265e13b6a1",
   "metadata": {},
   "source": [
    "### Loading logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f46f5785-8819-4737-b351-5279df9eac15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from fairseq_signals.utils.store import MemmapReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e228ae-a72b-492d-ba0c-be2a667c6a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the array of computed logits\n",
    "logits = MemmapReader.from_header(\n",
    "    os.path.join(root, 'outputs/outputs_test.npy')\n",
    ")[:]\n",
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2630a917-ec9c-41f3-8587-9d94625c4d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct predictions from logits\n",
    "pred = pd.DataFrame(\n",
    "    torch.sigmoid(torch.tensor(logits)).numpy(),\n",
    "    columns=label_names,\n",
    ")\n",
    "\n",
    "# Join in sample information\n",
    "pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')\n",
    "pred"
   ]
  },
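  {
   "cell_type": "markdown",
   "id": "f3a6c9d1-7e2b-4b85-9d14-0c8a7b6e5f49",
   "metadata": {},
   "source": [
    "The sigmoid is applied to each logit independently, so every label receives its own probability and the values in a row need not sum to one. The cell below is a dependency-free sketch of the same transform on hypothetical logits. The `sigmoid` helper is illustrative only; the notebook itself uses `torch.sigmoid`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b8c9d0-3e4f-4a5b-8c6d-9e0f1a2b3c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    # Numerically stable logistic function for a single logit\n",
    "    if x >= 0:\n",
    "        return 1.0 / (1.0 + math.exp(-x))\n",
    "    z = math.exp(x)\n",
    "    return z / (1.0 + z)\n",
    "\n",
    "\n",
    "# Hypothetical logits for three labels of one sample\n",
    "probs = [sigmoid(x) for x in [-2.0, 0.0, 3.0]]\n",
    "print([round(p, 3) for p in probs])"
   ]
  },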
  {
   "cell_type": "markdown",
   "id": "e8289fa5-508b-4499-9801-6b02273d1eb9",
   "metadata": {},
   "source": [
    "# Pretrained embeddings\n",
    "To obtain pretrained embeddings (e.g., for use as a feature set or for linear probing), the simplest way is to run `fairseq-hydra-train` to convert a pretrained model into the finetuning model format, which can then be run through `fairseq-hydra-validate` with `common_eval.extract=[encoder_out]`. Include the following arguments to ensure that no training actually occurs:\n",
    "```\n",
    "    optimization.lr=[1e-25] \\\n",
    "    optimization.max_update=1 \\\n",
    "    checkpoint.save_interval_updates=1 \\\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f3c9465-bd9f-4d72-aeb7-23bc22f4bc9a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fairseq",
   "language": "python",
   "name": "fairseq"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
